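# Organise the training data according to the mapping in mapping.txt.
# Start with a directory full of .tar files; this script extracts every
# archive and sorts the images into one directory per label so that the
# training code can run on that directory.
# NOTE: this is destructive - extracted archives and images that are not
# kept are deleted from the input directory.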

import os
import random
import shutil
import json
import tarfile
from collections import defaultdict
import pdb

# Concepts (labels) we want to keep, as listed in the word2vec rows file.
with open('../word2vec/rows.json') as f:
    concepts = json.load(f)

# The membership test below runs once per mapping line, so use a hashed
# lookup rather than scanning the loaded container each time.
# NOTE(review): assumes rows.json holds hashable entries (a list of label
# strings, or an object keyed by label) - confirm against the data file.
concept_lookup = frozenset(concepts)

# synset id -> label, restricted to labels that are known concepts.
# Each line of mapping.txt is read as "<synset> <label> ...".
mappings = {}
with open('../imagenet_labels/mapping.txt') as f:
    for line in f:
        fields = line.split()
        # Skip blank or truncated lines instead of dying with IndexError.
        if len(fields) < 2:
            continue
        synset, label = fields[0], fields[1]
        if label in concept_lookup:
            mappings[synset] = label

imagedir_in = '/local/filespace-2/fh295/DATA/train'

# Archives are selected by name: the part before the first '.' must be a
# synset id that has a mapped label.
tar_filenames = [code for code in os.listdir(imagedir_in)
                 if code.startswith('n') and code.endswith('tar')
                 and code.split('.')[0] in mappings]

print('total of %s tar files in the directory' % (len(tar_filenames)))

# Extract every archive into imagedir_in, then delete the archive.
for i, tar_name in enumerate(tar_filenames):
    full_f = os.path.join(imagedir_in, tar_name)
    # The context manager closes the archive even if extraction fails, so
    # file descriptors are not leaked across the loop.
    with tarfile.open(full_f, 'r') as t:
        t.extractall(path=imagedir_in)
    # Report a running count (1-based) rather than the loop index.
    print('extracted %s tar files' % (i + 1))
    os.remove(full_f)


# Extracted images are matched by name: "<synset>_....JPEG" where the
# synset prefix has a mapped label.
img_filenames = [f for f in os.listdir(imagedir_in)
                 if f.endswith('JPEG') and f.split('_')[0] in mappings]

# label -> list of image filenames belonging to that label.
# Several synsets may share one label, so group by the mapped label.
class_image_dict = defaultdict(list)
for img in img_filenames:
    synset = img.split('_')[0]
    class_image_dict[mappings[synset]].append(img)

# Fixed per-class cap on the number of images kept (not computed from the
# actual class sizes).
min_class_size = 1000
print('min class size of training set is %s' % (min_class_size))

# Give every label its own directory under imagedir_in, filled with a
# random sample of at most min_class_size of its images. Files are moved,
# and images beyond the cap are deleted from imagedir_in.
for cls, imgs in class_image_dict.items():
    random.shuffle(imgs)
    # Slicing already clamps to the list length, so small classes keep
    # all of their images (no image is dropped by an off-by-one).
    imgs_chosen = imgs[:min_class_size]
    imgs_surplus = imgs[min_class_size:]

    newdir = os.path.join(imagedir_in, cls)
    if not os.path.exists(newdir):
        os.mkdir(newdir)
        print('made new directory %s' % (cls))

    for img in imgs_chosen:
        # An explicit destination filename keeps the overwrite-on-rerun
        # behaviour of the previous copy-then-delete.
        shutil.move(os.path.join(imagedir_in, img),
                    os.path.join(newdir, img))
    for img in imgs_surplus:
        os.remove(os.path.join(imagedir_in, img))



		






